{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# LSTM (Long Short Term Memory)\n",
    "\n",
    "There is a branch of Deep Learning that is dedicated to processing time series. These deep Nets are **Recursive Neural Nets (RNNs)**. LSTMs are one of the few types of RNNs that are available. Gated Recurent Units (GRUs) are the other type of popular RNNs.\n",
    "\n",
    "This is an illustration from http://colah.github.io/posts/2015-08-Understanding-LSTMs/ (A highly recommended read)\n",
    "\n",
    "![RNNs](../RNN-unrolled.png)\n",
    "\n",
    "### Futher Reading\n",
    "1. http://karpathy.github.io/2015/05/21/rnn-effectiveness/\n",
    "2. http://colah.github.io/posts/2015-08-Understanding-LSTMs/\n",
    "\n",
    "### YouTube Video\n",
    "1. https://www.youtube.com/watch?v=ywinX5wgdEU\n",
    "2. https://www.youtube.com/watch?v=e1pEIYVOtqc\n",
    "\n",
    "Pros:\n",
    "- Really powerful pattern recognition system for time series\n",
    "\n",
    "Cons:\n",
    "- Cannot deal with missing time steps.\n",
    "- Time steps must be discretised and not continuous.\n",
    "\n",
    "![trump](../images/trump.jpg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<iframe width=\"560\" height=\"315\" src=\"https://www.youtube.com/embed/e1pEIYVOtqc?rel=0&amp;controls=0&amp;showinfo=0\" frameborder=\"0\" allowfullscreen></iframe>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython.display import HTML\n",
    "\n",
    "HTML('<iframe width=\"560\" height=\"315\" src=\"https://www.youtube.com/embed/e1pEIYVOtqc?rel=0&amp;controls=0&amp;showinfo=0\" frameborder=\"0\" allowfullscreen></iframe>')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import re\n",
    "\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Activation, Dropout, Flatten, Dense, BatchNormalization, LSTM, Embedding, TimeDistributed\n",
    "from keras.models import load_model, model_from_json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def chr2val(ch):\n",
    "    ch = ch.lower()\n",
    "    if ch.isalpha():\n",
    "        return 1 + (ord(ch) - ord('a'))\n",
    "    else:\n",
    "        return 0\n",
    "    \n",
    "def val2chr(v):\n",
    "    if v == 0:\n",
    "        return ' '\n",
    "    else:\n",
    "        return chr(ord('a') + v - 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style>\n",
       "    .dataframe thead tr:only-child th {\n",
       "        text-align: right;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: left;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>source</th>\n",
       "      <th>text</th>\n",
       "      <th>created_at</th>\n",
       "      <th>favorite_count</th>\n",
       "      <th>is_retweet</th>\n",
       "      <th>id_str</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Twitter for iPhone</td>\n",
       "      <td>i think senator blumenthal should take a nice ...</td>\n",
       "      <td>08-07-2017 20:48:54</td>\n",
       "      <td>61446</td>\n",
       "      <td>false</td>\n",
       "      <td>8.946617e+17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Twitter for iPhone</td>\n",
       "      <td>how much longer will the failing nytimes with ...</td>\n",
       "      <td>08-07-2017 20:39:46</td>\n",
       "      <td>42235</td>\n",
       "      <td>false</td>\n",
       "      <td>8.946594e+17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Twitter for iPhone</td>\n",
       "      <td>the fake news media will not talk about the im...</td>\n",
       "      <td>08-07-2017 20:15:18</td>\n",
       "      <td>45050</td>\n",
       "      <td>false</td>\n",
       "      <td>8.946532e+17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Twitter for iPhone</td>\n",
       "      <td>on #purpleheartday i thank all the brave men a...</td>\n",
       "      <td>08-07-2017 18:03:42</td>\n",
       "      <td>48472</td>\n",
       "      <td>false</td>\n",
       "      <td>8.946201e+17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>Twitter for iPhone</td>\n",
       "      <td>...conquests how brave he was and it was all a...</td>\n",
       "      <td>08-07-2017 12:01:20</td>\n",
       "      <td>59253</td>\n",
       "      <td>false</td>\n",
       "      <td>8.945289e+17</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "               source                                               text  \\\n",
       "0  Twitter for iPhone  i think senator blumenthal should take a nice ...   \n",
       "1  Twitter for iPhone  how much longer will the failing nytimes with ...   \n",
       "2  Twitter for iPhone  the fake news media will not talk about the im...   \n",
       "4  Twitter for iPhone  on #purpleheartday i thank all the brave men a...   \n",
       "5  Twitter for iPhone  ...conquests how brave he was and it was all a...   \n",
       "\n",
       "            created_at favorite_count is_retweet        id_str  \n",
       "0  08-07-2017 20:48:54          61446      false  8.946617e+17  \n",
       "1  08-07-2017 20:39:46          42235      false  8.946594e+17  \n",
       "2  08-07-2017 20:15:18          45050      false  8.946532e+17  \n",
       "4  08-07-2017 18:03:42          48472      false  8.946201e+17  \n",
       "5  08-07-2017 12:01:20          59253      false  8.945289e+17  "
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.read_csv('trump.csv')\n",
    "df = df[df.is_retweet=='false']\n",
    "df.text = df.text.str.lower()\n",
    "df.text = df.text.str.replace(r'http[\\w:/\\.]+','') # remove urls\n",
    "df.text = df.text.str.replace(r'[^!\\'\"#$%&\\()*+,-./:;<=>?@_’`{|}~\\w\\s]',' ') #remove everything but characters and punctuation\n",
    "df.text = df.text.str.replace(r'\\s\\s+',' ') #replace multple white space with a single one\n",
    "df = df[[len(t)<180 for t in df.text.values]]\n",
    "df = df[[len(t)>50 for t in df.text.values]]\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(23938, 6)"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Remove emojis, flags etc from tweets. Also notice how I have used `[::-1]` to indicate that I want the tweets in chrnological order."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['be sure to tune in and watch donald trump on late night with david letterman as he presents the top ten list tonight!',\n",
       " 'donald trump will be appearing on the view tomorrow morning to discuss celebrity apprentice and his new book think like a champion!',\n",
       " 'donald trump reads top ten financial tips on late show with david letterman:  - very funny!',\n",
       " 'new blog post: celebrity apprentice finale and lessons learned along the way: ',\n",
       " 'my persona will never be that of a wallflower - i’d rather build walls than cling to them --donald j. trump']"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trump_tweets = [text for text in df.text.values[::-1]]\n",
    "trump_tweets[:5]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create a dictionary to convert letters to numbers and vice versa."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "all_tweets = ''.join(trump_tweets)\n",
    "char2int = dict(zip(set(all_tweets), range(len(set(all_tweets)))))\n",
    "char2int['<END>'] = len(char2int)\n",
    "char2int['<GO>'] = len(char2int)\n",
    "char2int['<PAD>'] = len(char2int)\n",
    "int2char = dict(zip(char2int.values(), char2int.keys()))"
   ]
  },
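  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of the mapping idea, the sketch below builds the same kind of lookup tables on a toy string and verifies that encoding followed by decoding returns the original text. The names `toy_text`, `toy_char2int` and `toy_int2char` are illustrative and not part of the pipeline above. The sketch sorts the characters first, which is an assumption of this example rather than what the cell above does: it makes the ids reproducible between Python sessions, which `set` ordering alone does not guarantee for strings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy round trip: text -> ids -> text (illustrative names, not used elsewhere)\n",
    "toy_text = 'hello world!'\n",
    "toy_chars = sorted(set(toy_text))  # sorted so that ids are stable across sessions\n",
    "toy_char2int = {c: i for i, c in enumerate(toy_chars)}\n",
    "for token in ('<END>', '<GO>', '<PAD>'):\n",
    "    toy_char2int[token] = len(toy_char2int)  # special tokens get the next free ids\n",
    "toy_int2char = {i: c for c, i in toy_char2int.items()}\n",
    "\n",
    "# Wrap the encoded text in <GO> ... <END>, as done for the tweets\n",
    "encoded = [toy_char2int['<GO>']] + [toy_char2int[c] for c in toy_text] + [toy_char2int['<END>']]\n",
    "decoded = ''.join(toy_int2char[i] for i in encoded[1:-1])\n",
    "\n",
    "assert decoded == toy_text\n",
    "assert len(toy_char2int) == len(toy_int2char)  # the mapping is one-to-one\n",
    "print(encoded)"
   ]
  },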
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "text_num = [[char2int['<GO>']]+[char2int[c] for c in tweet]+ [char2int['<END>']] for tweet in trump_tweets]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYEAAAD8CAYAAACRkhiPAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAENFJREFUeJzt3X+snuVdx/H3R7oh++XAdpWVznamMwES2aiITg0TM7rN\nWPbPUqKDRaSL4LKZRS1b4uYfTdjcj4REUCZI0QlpNiZNBjpGpssSgR2Q0R+soY4yWgvtXJSpCQ72\n9Y/nYjw7nNPz+zzn9Hq/kjvP/XzvH+e6aPt8zn3d13OTqkKS1KcfG3UDJEmjYwhIUscMAUnqmCEg\nSR0zBCSpY4aAJHXMEJCkjhkCktQxQ0CSOrZi1A2YysqVK2vdunWjboYkLSsPPPDAd6pq1VT7LfkQ\nWLduHWNjY6NuhiQtK0ken85+DgdJUscMAUnqmCEgSR0zBCSpY4aAJHXMEJCkjhkCktQxQ0CSOmYI\nSFLHlvw3hiUtHeu2fXHC+sFr3rHILdF88UpAkjpmCEhSxwwBSeqYISBJHTMEJKljhoAkdcwQkKSO\nGQKS1DFDQJI6ZghIUscMAUnqmCEgSR2bMgSSrE3ylST7kuxN8v5W/2iSw0keasvbh465OsmBJPuT\nXDRUPzfJ7rbt2iRZmG5JkqZjOk8RfRb4YFU9mOSVwANJ7m7bPl1VnxjeOcmZwBbgLOC1wJeTvKGq\nngOuB64A7gPuBDYBd81PVyRJMzXllUBVHamqB9v694BHgDXHOWQzcFtVPVNVjwEHgPOSnA68qqru\nraoCbgEunnMPJEmzNqN7AknWAW9k8Js8wPuSPJzkpiSnttoa4Imhww612pq2Pr4uSRqRaYdAklcA\nnwc+UFVPMxjaeT1wDnAE+OR8NSrJ1iRjScaOHTs2X6eVJI0zrRBI8hIGAfDZqrodoKqeqqrnquoH\nwGeA89ruh4G1Q4ef0WqH2/r4+otU1Q1VtbGqNq5atWom/ZEkzcB0ZgcFuBF4pKo+NVQ/fWi3dwJ7\n2vouYEuSk5OsBzYA91fVEeDpJOe3c14K3DFP/ZAkzcJ0Zge9GXg3sDvJQ632IeCSJOcABRwE3gtQ\nVXuT7AT2MZhZdFWbGQRwJXAzcAqDWUHODJKkEZoyBKrqa8BE8/nvPM4x24HtE9THgLNn0kBJ0sLx\nG8OS1DFDQJI6ZghIUscMAUnqmCEgSR0zBCSpY4aAJHXMEJCkjhkCktQxQ0CSOmYISFLHDAFJ6pgh\nIEkdMwQkqWOGgCR1zBCQpI4ZApLUMUNAkjpmCEhSxwwBSeqYISBJHTMEJKljhoAkdcwQkKSOGQKS\n1DFDQJI6ZghIUscMAUnqmCEgSR2bMgSSrE3ylST7kuxN8v5WPy3J3Ukeba+nDh1zdZIDSfYnuWio\nfm6S3W3btUmyMN2SJE3HdK4EngU+WFVnAucDVyU5E9gG3FNVG4B72nvati3AWcAm4LokJ7VzXQ9c\nAWxoy6Z57IskaYamDIGqOlJVD7b17wGPAGuAzcCOttsO4OK2vhm4raqeqarHgAPAeUlOB15VVfdW\nVQG3DB0jSRqBGd0TSLIOeCNwH7C6qo60TU8Cq9v6GuCJocMOtdqatj6+LkkakWmHQJJXAJ8HPlBV\nTw9va7/Z13w1KsnWJGNJxo4dOzZfp5UkjTOtEEjyEgYB8Nmqur2Vn2pDPLTXo61+GFg7dPgZrXa4\nrY+vv0hV3VBVG6tq46pVq6bbF0nSDE1ndlCAG4FHqupTQ5t2AZe19cuAO4bqW5KcnGQ9gxvA97eh\no6eTnN/OeenQMZKkEVgxjX3eDLwb2J3koVb7EHANsDPJ5cDjwLsAqmpvkp3APgYzi66qqufacVcC\nNwOnAHe1RZI0IlOGQFV9DZhsPv+FkxyzHdg+QX0MOHsmDZQkLRy/MSxJHTMEJKljhoAkdcwQkKSO\nGQKS1DFDQJI6ZghIUscMAUnqmCEgSR
0zBCSpY4aAJHXMEJCkjhkCktQxQ0CSOmYISFLHDAFJ6pgh\nIEkdMwQkqWOGgCR1zBCQpI4ZApLUMUNAkjpmCEhSxwwBSeqYISBJHTMEJKljhoAkdcwQkKSOGQKS\n1LEpQyDJTUmOJtkzVPtoksNJHmrL24e2XZ3kQJL9SS4aqp+bZHfbdm2SzH93JEkzMZ0rgZuBTRPU\nP11V57TlToAkZwJbgLPaMdclOantfz1wBbChLROdU5K0iKYMgar6KvDdaZ5vM3BbVT1TVY8BB4Dz\nkpwOvKqq7q2qAm4BLp5toyVJ82Mu9wTel+ThNlx0aqutAZ4Y2udQq61p6+PrkqQRmm0IXA+8HjgH\nOAJ8ct5aBCTZmmQsydixY8fm89SSpCGzCoGqeqqqnquqHwCfAc5rmw4Da4d2PaPVDrf18fXJzn9D\nVW2sqo2rVq2aTRMlSdMwqxBoY/zPeyfw/MyhXcCWJCcnWc/gBvD9VXUEeDrJ+W1W0KXAHXNotyRp\nHqyYaocktwIXACuTHAI+AlyQ5ByggIPAewGqam+SncA+4Fngqqp6rp3qSgYzjU4B7mqLJGmEpgyB\nqrpkgvKNx9l/O7B9gvoYcPaMWidJWlB+Y1iSOmYISFLHDAFJ6pghIEkdMwQkqWOGgCR1zBCQpI4Z\nApLUsSm/LCZpZtZt++KE9YPXvGORWyJNzSsBSeqYISBJHXM4SOqYQ1fySkCSOmYISFLHDAFJ6pgh\nIEkdMwQkqWPODpL0IpPNGtKJxysBSeqYISBJHXM4SNKc+aWz5csrAUnqmCEgSR0zBCSpY4aAJHXM\nEJCkjhkCktQxQ0CSOmYISFLHpgyBJDclOZpkz1DttCR3J3m0vZ46tO3qJAeS7E9y0VD93CS727Zr\nk2T+uyNJmonpXAncDGwaV9sG3FNVG4B72nuSnAlsAc5qx1yX5KR2zPXAFcCGtow/pyRpkU0ZAlX1\nVeC748qbgR1tfQdw8VD9tqp6pqoeAw4A5yU5HXhVVd1bVQXcMnSMJGlEZntPYHVVHWnrTwKr2/oa\n4Imh/Q612pq2Pr4uSRqhOd8Ybr/Z1zy05YeSbE0ylmTs2LFj83lqSdKQ2YbAU22Ih/Z6tNUPA2uH\n9juj1Q639fH1CVXVDVW1sao2rlq1apZNlCRNZbYhsAu4rK1fBtwxVN+S5OQk6xncAL6/DR09neT8\nNivo0qFjJEkjMuX/TyDJrcAFwMokh4CPANcAO5NcDjwOvAugqvYm2QnsA54Frqqq59qprmQw0+gU\n4K62SJJGaMoQqKpLJtl04ST7bwe2T1AfA86eUeskSQvKbwxLUscMAUnqmCEgSR0zBCSpY4aAJHVs\nytlBUi/WbfvihPWD17xjkVsiLR6vBCSpY4aAJHXMEJCkjnlPQJqlye4hzNd5vBehxeCVgCR1zBCQ\npI4ZApLUMUNAkjpmCEhSx5wdJC2S+ZpNdLxzOaNIM2UI6ITlB6U0NYeDJKljhoAkdczhIC0bDu9I\n888rAUnqmFcC0hTmc1aPtNQYAlIHDDJNxuEgSeqYISBJHXM4SAvKGT3S0uaVgCR1zBCQpI45HCSd\nQJwFpJma05VAkoNJdid5KMlYq52W5O4kj7bXU4f2vzrJgST7k1w018ZLkuZmPoaD3lJV51TVxvZ+\nG3BPVW0A7mnvSXImsAU4C9gEXJfkpHn4+ZKkWVqI4aDNwAVtfQfwT8Aft/ptVfUM8FiSA8B5wL8s\nQBu0QJztI51Y5hoCBXw5yXPAX1bVDcDqqjrStj8JrG7ra4B7h4491GpaBH54S5rIXEPgl6vqcJLX\nAHcn+ebwxqqqJDXTkybZCmwFeN3rXjfHJko/ypun0gvmdE+gqg6316PAFxgM7zyV5HSA9nq07X4Y\nWDt0+BmtNtF5b6iqjVW1cdWqVXNpoiTpOGYdAklenuSVz68DbwX2ALuAy9pulwF3tPVdwJYkJydZ\nD2
wA7p/tz5ckzd1choNWA19I8vx5/q6q/iHJ14GdSS4HHgfeBVBVe5PsBPYBzwJXVdVzc2r9CcQx\n+9lzeEeavVmHQFV9C/i5Cer/AVw4yTHbge2z/Zl6wVILDT+I55//TbUY/MawRsIPOGlp8NlBktQx\nQ0CSOuZw0BwstXF5SZoprwQkqWOGgCR1zOEgTcjZO1IfDIEhjvFL6o0h0Dl/45f6ZgicYPxQlzQT\nhsAS54e6pIV0QoeAY/ySdHwndAgsNf5WL2mpMQQWgB/2kpYLQ2Aa/FCXdKIyBCQtGO/LLX0+NkKS\nOuaVgKRF5xXC0tFlCDjGL0kDDgdJUscMAUnqmCEgSR0zBCSpY4aAJHXMEJCkjhkCktQxQ0CSOmYI\nSFLHDAFJ6tiih0CSTUn2JzmQZNti/3xJ0gsWNQSSnAT8OfA24EzgkiRnLmYbJEkvWOwHyJ0HHKiq\nbwEkuQ3YDOxb5HZIWoJ8uujiW+wQWAM8MfT+EPALi9wGScuM4bBwluSjpJNsBba2t/+dZP8o2zOJ\nlcB3Rt2IeWA/lpYTpR+wCH3Jxxby7D+0XP9Mfno6Oy12CBwG1g69P6PVfkRV3QDcsFiNmo0kY1W1\ncdTtmCv7sbScKP2AE6cvJ0o/JrPYs4O+DmxIsj7JS4EtwK5FboMkqVnUK4GqejbJ7wP/CJwE3FRV\nexezDZKkFyz6PYGquhO4c7F/7gJY0sNVM2A/lpYTpR9w4vTlROnHhFJVo26DJGlEfGyEJHXMEJiG\nJK9O8rkk30zySJJfTHJakruTPNpeTx11O6eS5A+S7E2yJ8mtSX58ufQjyU1JjibZM1SbtO1Jrm6P\nJtmf5KLRtPrFJunHn7W/Ww8n+UKSVw9tWzb9GNr2wSSVZOVQbUn2AybvS5L3tT+XvUk+PlRfsn2Z\nlapymWIBdgC/29ZfCrwa+DiwrdW2AR8bdTun6MMa4DHglPZ+J/Ce5dIP4FeBNwF7hmoTtp3BI0m+\nAZwMrAf+DThp1H04Tj/eCqxo6x9brv1o9bUMJn48Dqxc6v04zp/JW4AvAye3969ZDn2ZzeKVwBSS\n/ASDvyQ3AlTV/1XVfzJ43MWOttsO4OLRtHBGVgCnJFkBvAz4d5ZJP6rqq8B3x5Una/tm4Laqeqaq\nHgMOMHhkychN1I+q+lJVPdve3svg+zOwzPrRfBr4I2D4ZuOS7QdM2pffA66pqmfaPkdbfUn3ZTYM\ngamtB44Bf53kX5P8VZKXA6ur6kjb50lg9chaOA1VdRj4BPBt4AjwX1X1JZZZP8aZrO0TPZ5kzWI2\nbA5+B7irrS+rfiTZDByuqm+M27Ss+tG8AfiVJPcl+eckP9/qy7Evx2UITG0Fg0vF66vqjcD/MBh6\n+KEaXCcu6WlWbbx8M4NQey3w8iS/PbzPcujHZJZz25+X5MPAs8BnR92WmUryMuBDwJ+Mui3zZAVw\nGnA+8IfAziQZbZMWhiEwtUPAoaq6r73/HINQeCrJ6QDt9egkxy8Vvw48VlXHqur7wO3AL7H8+jFs\nsrZP6/EkS0mS9wC/AfxWCzRYXv34GQa/YHwjyUEGbX0wyU+xvPrxvEPA7TVwP/ADBs8QWo59OS5D\nYApV9STwRJKfbaULGTz6ehdwWatdBtwxgubNxLeB85O8rP1GcyHwCMuvH8Mma/suYEuSk5OsBzYA\n94+gfdOSZBODcfTfrKr/Hdq0bPpRVbur6jVVta6q1jH4EH1T+/ezbPox5O8Z3BwmyRsYTAj5Dsuz\nL8c36jvTy2EBzgHGgIcZ/OU4FfhJ4B7gUQazCE4bdTun0Y8/Bb4J7AH+hsEMh2XRD+BWBvcyvs/g\nA+by47Ud+DCDmRv7gbeNuv1T9OMAg3Hmh9ryF8uxH+O2H6TNDlrK/TjOn8lLgb9t/1YeBH5tOfRl\nNovfGJakjjkcJEkdMwQkqWOGgCR1zBCQpI4ZApLUMUNAkjpmCEhS
xwwBSerY/wNaeyYNKBoJXgAA\nAABJRU5ErkJggg==\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f429b47d550>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.hist([len(t) for t in trump_tweets],50)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Concatenate all the tweets\n",
    "int_text = []\n",
    "for t in text_num:\n",
    "    int_text += t\n",
    "\n",
    "len_vocab = len(char2int)\n",
    "sentence_len = 40\n",
    "# n_chars = len(text_num)//sentence_len*sentence_len\n",
    "num_chunks = len(text_num)-sentence_len\n",
    "\n",
    "def get_batches(int_text, batch_size, seq_length):\n",
    "    \"\"\"\n",
    "    Return batches of input and target\n",
    "    :param int_text: Text with the words replaced by their ids\n",
    "    :param batch_size: The size of batch\n",
    "    :param seq_length: The length of sequence\n",
    "    :return: Batches as a Numpy array\n",
    "    \"\"\"\n",
    "    \n",
    "    slice_size = batch_size * seq_length\n",
    "    n_batches = len(int_text) // slice_size\n",
    "    x = int_text[: n_batches*slice_size]\n",
    "    y = int_text[1: n_batches*slice_size + 1]\n",
    "\n",
    "    x = np.split(np.reshape(x,(batch_size,-1)),n_batches,1)\n",
    "    y = np.split(np.reshape(y,(batch_size,-1)),n_batches,1)\n",
    "    \n",
    "    x = np.vstack(x)\n",
    "    y = np.vstack(y)\n",
    "    y = y.reshape(y.shape+(1,))\n",
    "    return x, y\n",
    "\n",
    "batch_size = 128\n",
    "x, y = get_batches(int_text, batch_size, sentence_len)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice what the `get_batches` function looks like."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0,  1,  2,  3],\n",
       "       [ 4,  5,  6,  7],\n",
       "       [ 8,  9, 10, 11],\n",
       "       [12, 13, 14, 15]])"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.arange(16).reshape((-1,4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16])"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.arange(1,17)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0,  1,  2,  3],\n",
       "       [ 8,  9, 10, 11],\n",
       "       [ 4,  5,  6,  7],\n",
       "       [12, 13, 14, 15]])"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x1, y1 = get_batches(np.arange(17), 2,4)\n",
    "x1"
   ]
  },
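  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the layout explicit, the sketch below repeats the reshape, split and stack steps of `get_batches` on the same toy sequence and checks two properties: the targets are the inputs shifted by one step, and row `i` of the second batch continues row `i` of the first batch. The names `toy`, `rows_x`, `rows_y`, `x_toy` and `y_toy` are illustrative only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "batch, seq = 2, 4\n",
    "toy = np.arange(17)\n",
    "n_batches = len(toy) // (batch * seq)  # number of full batches that fit\n",
    "\n",
    "# One long row per batch slot, then cut the rows into seq-length pieces\n",
    "rows_x = np.reshape(toy[:n_batches * batch * seq], (batch, -1))\n",
    "rows_y = np.reshape(toy[1:n_batches * batch * seq + 1], (batch, -1))\n",
    "x_toy = np.vstack(np.split(rows_x, n_batches, axis=1))\n",
    "y_toy = np.vstack(np.split(rows_y, n_batches, axis=1))\n",
    "\n",
    "# Targets are the inputs shifted by one (toy is an arange, so next value = current + 1)\n",
    "assert np.array_equal(y_toy, x_toy + 1)\n",
    "\n",
    "# Row i of the second batch starts where row i of the first batch stopped\n",
    "first, second = x_toy[:batch], x_toy[batch:]\n",
    "assert np.array_equal(second[:, 0], first[:, -1] + 1)\n",
    "print(x_toy)"
   ]
  },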
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## Many to Many LSTM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "embedding_2 (Embedding)      (128, None, 16)           1360      \n",
      "_________________________________________________________________\n",
      "lstm_3 (LSTM)                (128, None, 64)           20736     \n",
      "_________________________________________________________________\n",
      "lstm_4 (LSTM)                (128, None, 64)           33024     \n",
      "_________________________________________________________________\n",
      "time_distributed_2 (TimeDist (128, None, 85)           5525      \n",
      "=================================================================\n",
      "Total params: 60,645.0\n",
      "Trainable params: 60,645\n",
      "Non-trainable params: 0.0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model = Sequential()\n",
    "model.add(Embedding(len_vocab, 16, batch_size=batch_size)) # , batch_size=batch_size\n",
    "model.add(LSTM(64, return_sequences=True, stateful=True)) # , stateful=True\n",
    "model.add(LSTM(64, return_sequences=True, stateful=True))\n",
    "model.add(TimeDistributed(Dense(len_vocab, activation='softmax')))\n",
    "\n",
    "model.compile(loss='sparse_categorical_crossentropy', optimizer='adam')\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Pay special attention to how the probabilites are taken. p is of shape `(1, sequence_len, len(char2int))` where len(char2int) is the number of available characters. The 1 is there because we are only predicting one feature, `y`. We are only concerned about the last prediction probability of the sequence. This is due to the fact that all other letters have already been appended. Hence we predict a letter from the distribution `p[0][-1]`.\n",
    "\n",
    "Why did we keep appending to the sequence and predicting? Why not use simply the last letter. If we were to do this, we would lose information that comes from the previous letter via the hidden state and cell memory. Keep in mind that each LSTM unit has 3 inputs, the x, the hidden state, and the cell memory. \n",
    "\n",
    "Also important to notice that the Cell Memory is not used in connecting to the Dense layer, only the hidden state.\n",
    "\n",
    "### Stateful training:\n",
    "\n",
    "What happens when `stateful=True` is that the last cell memory state computed at the i-th example in the n-th batch gets passed on to i-th sample in the n+1-th batch. This is one way of _seeing_ patterns beyond the sentence length specified. Which in this case is 40.\n",
    "\n",
    "#### Note\n",
    "1. Really important that when I `fit` the model I set `shuffle=False` when training stateful models.\n",
    "2. I need to copy the weights to a non-stateful model because the original only takes `batch_size` inputs at a time. Whereas, I only want to predict one character at a time."
   ]
  },
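  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sampling step described above (drawing the next letter from `p[0][-1]`) can be sketched with NumPy alone. Here `p_toy` is a made-up stand-in for the output of `predict`: random scores passed through a softmax, so the numbers are meaningless and only the shapes and the indexing mirror the real case. `vocab_size` and `seq_len` are arbitrary toy values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.RandomState(0)\n",
    "vocab_size, seq_len = 5, 3\n",
    "\n",
    "# Stand-in for a prediction: one sequence, seq_len time steps, vocab_size probabilities each\n",
    "scores = rng.randn(1, seq_len, vocab_size)\n",
    "p_toy = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)\n",
    "\n",
    "last_step = p_toy[0][-1]  # distribution over the next character\n",
    "assert last_step.shape == (vocab_size,)\n",
    "assert np.isclose(last_step.sum(), 1.0)\n",
    "\n",
    "# Sample an id instead of taking the argmax, so generated text varies between calls\n",
    "next_id = rng.choice(vocab_size, p=last_step)\n",
    "assert 0 <= next_id < vocab_size\n",
    "print(next_id)"
   ]
  },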
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "model2 = Sequential()\n",
    "model2.add(Embedding(len_vocab, 16)) # , batch_size=batch_size\n",
    "model2.add(LSTM(64, return_sequences=True)) # , stateful=True\n",
    "model2.add(LSTM(64, return_sequences=True))\n",
    "model2.add(TimeDistributed(Dense(len_vocab, activation='softmax')))\n",
    "\n",
    "def generate_sentence(model, sentence_len):\n",
    "    sentence = []\n",
    "    letter = [char2int['<GO>']] #choose a random letter\n",
    "    for i in range(sentence_len):\n",
    "        sentence.append(int2char[letter[-1]])\n",
    "        model2.set_weights(model.get_weights())\n",
    "        p = model2.predict(np.array(letter)[None,:])\n",
    "        letter.append(np.random.choice(len(char2int),1,p=p[0][-1])[0])\n",
    "    \n",
    "    return ''.join(sentence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<GO>\n",
      "1/@ıbp. pet ocl  u1pn<END><GO><GO>grbaltlrdbt marapueanos taánaaoe @ratodec.ruto- pfe eamfodhyt ia. bpaldewr\n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 111s - loss: 2.5847   \n",
      "<GO>.!- <GO>whingment do3% be @banemviqrush: dotarnion.<END><GO>hill diir. baden\" reations bimman in @realdonactr\n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 119s - loss: 1.9990   \n",
      "<GO>\"! @myjinnalnjats: @reealebbparriu<END><GO>bod exbissed great!<END><GO>@erutbzart: i duserest got!u deald reangne\n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 104s - loss: 1.8533   \n",
      "<GO><GO>onye right\" #debloss chinastivate!<END><GO>@marmervalf: @realdonaldtrump @realdonaldtrump is many had in \n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 114s - loss: 1.7786   \n",
      "<GO> will never newed wid - will filt unfell! @neme not's ambess in @barepelistber the arith 4 milly wo\n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "  256/66560 [..............................] - ETA: 200s - loss: 1.8051"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/lib/python3.6/site-packages/keras/callbacks.py:118: UserWarning: Method on_batch_end() is slow compared to the batch update (0.392488). Check your callbacks.\n",
      "  % delta_t_median)\n",
      "/root/miniconda3/lib/python3.6/site-packages/keras/callbacks.py:118: UserWarning: Method on_batch_end() is slow compared to the batch update (0.196978). Check your callbacks.\n",
      "  % delta_t_median)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "66560/66560 [==============================] - 110s - loss: 1.7286   \n",
      "<GO> then by to @foxnews!<END><GO>gold clue onlini nerdities be is lost plepit will i cangon storogts. us this\n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 98s - loss: 1.6912    \n",
      "<GO>cnank tourstes the news. be ragelil can valls at it that menting who wants ny's going of for for to\n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 105s - loss: 1.6631   \n",
      "<GO>attodays ware @mbh starmentbe you-ene are ever with mitt open stmentuled and i housnist are  <END><GO> my \n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 105s - loss: 1.6400   \n",
      "<GO>min balm!<END><GO>@am584  tremince have angeriet’s #donitechug for me insredent!<END><GO>@moon_ter: @newyontrrud \n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 111s - loss: 1.6217   \n",
      "<GO>arelaces into kay for i kavity that's kastests i intervied you hand their never cares--he mind gavi\n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 93s - loss: 1.6061    \n",
      "<GO>muntwed specter ivank @realdonaldtrump<END><GO>@prani7: @tvs unmit in is the kenentionment thank win in mo\n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 106s - loss: 1.5930   \n"
     ]
    }
   ],
   "source": [
    "n_epochs = 50\n",
    "for i in range(n_epochs+1):\n",
    "    if i%5==0:\n",
    "        sentence = generate_sentence(model, 100)\n",
    "        print(sentence)\n",
    "        print('='*100)\n",
    "        v = 1\n",
    "    if i!=n_epochs:\n",
    "        model.fit(x,y, batch_size=batch_size, epochs=1, shuffle=False, verbose=v)\n",
    "        model.reset_states()\n",
    "        v = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<GO>mm thinks @beinkintard.  polio deliveds.<END><GO>@jenaorisemila: @realdanasura   i will resting release! g\n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 81s - loss: 1.5906    \n",
      "<GO><END><GO>efstinood-thwe see shout  .@endilaryone @esorropsec.is. hosted nowinutionais beremespess leader t\n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 101s - loss: 1.5798   \n",
      "<GO>-&ttrump her your foin never they're vie because can't beaters problems of durmy for put by r shour\n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 101s - loss: 1.5706   \n",
      "<GO>mumm2p @foxnewderated<END><GO>poleds..... is.<END><GO>quick barning twee who need o negether mitt at of teldmin! \n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 99s - loss: 1.5623    \n",
      "<GO>erear god an if instoiense for my fars!<END><GO>jiaded what be a greatest nomight of my @bofhaptroskec vot\n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 104s - loss: 1.5549   \n",
      "<GO>6<GO>sterry gong about the donoys like a look obamacare incleat obama 4/aw missiase thin carness more \n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 96s - loss: 1.5485    \n",
      "<GO>be strit sees lead!!<END><GO>@kahpeallellma @cnnbetrook<END><GO>@heacpchi: we must patin. loved stme\" us and pers\n",
      "====================================================================================================\n"
     ]
    }
   ],
   "source": [
    "n_epochs = 30\n",
    "for i in range(n_epochs+1):\n",
    "    if i%5==0:\n",
    "        sentence = generate_sentence(model, 100)\n",
    "        print(sentence)\n",
    "        print('='*100)\n",
    "        v = 1\n",
    "    if i!=n_epochs:\n",
    "        model.fit(x,y, batch_size=batch_size, epochs=1, shuffle=False, verbose=v)\n",
    "        model.reset_states()\n",
    "        v = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<GO>8<GO>bai condire @mittromney<END><GO>id more in the majoriful!<END><GO>geored this ressious &amp; worth-leaking me<END><GO>\n",
      "====================================================================================================\n",
      "<GO><END>61.; use stngorfur.<END><GO>thank you bring failed for on 48 for patter<END><GO>@tyoupcortani-   will be good at\n",
      "====================================================================================================\n",
      "<GO>6<GO>i q.. <END><GO>itona is both proobanf invised hin.<END><GO>qoldrew confiden statrazy etheric.. new with priwide\n",
      "====================================================================================================\n",
      "<GO>e nevotu something proshious will sthere.<END><GO>great i jackets anoscorday.<END><GO>very concord for nice! come\n",
      "====================================================================================================\n",
      "<GO>#tutlamprices in weicle fuborselo's governed anceir ffratel...  never saying invelueters don't doin\n",
      "====================================================================================================\n",
      "<GO>5<GO>stust to kufsline for hnited and celebrity watch after havo stet everybt wrate done car-i to peri\n",
      "====================================================================================================\n",
      "<GO>:. \"erusheek lave the story in the drodia in taking waiting foot a man snitiqutes for the man busis\n",
      "====================================================================================================\n",
      "<GO>6<GO>bl. @realdonaldtrump linning. he estafus to for a states and devansel we nominal rynese he contas\n",
      "====================================================================================================\n",
      "<GO>6<GO>stust!<END><GO>@wxejjonaticore: well rating about - an eligason #newssmit for you course listugaainess i\n",
      "====================================================================================================\n",
      "<GO><END>muman tax stwoteers cne\" rivatcold i midely!<END><GO>i'd procean to see you roweral @foxnews to save your\n",
      "====================================================================================================\n"
     ]
    }
   ],
   "source": [
    "for j in range(10):\n",
    "    sentence = generate_sentence(model, 100)\n",
    "    print(sentence)\n",
    "    print('='*100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 138,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "with open('model_struct.json','w') as f:\n",
    "    f.write(model.to_json())\n",
    "model.save_weights('model_weights.h5')\n",
    "model.save('model.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# if not 'model' in vars():\n",
    "# #     model = load_model('model.h5') # This doesn't seem to work for some odd reason\n",
    "#     with open('model_struct.json','r') as f:\n",
    "#         model = model_from_json(f.read())\n",
    "#     model.load_weights('model_weights.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "embedding_4 (Embedding)      (None, None, 64)          8512      \n",
      "_________________________________________________________________\n",
      "lstm_7 (LSTM)                (None, None, 64)          33024     \n",
      "_________________________________________________________________\n",
      "time_distributed_4 (TimeDist (None, None, 133)         8645      \n",
      "=================================================================\n",
      "Total params: 50,181.0\n",
      "Trainable params: 50,181\n",
      "Non-trainable params: 0.0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model = Sequential()\n",
    "model.add(Embedding(len_vocab, 64)) # , batch_size=batch_size\n",
    "model.add(LSTM(64, return_sequences=True)) # , stateful=True\n",
    "model.add(TimeDistributed(Dense(len_vocab, activation='softmax')))\n",
    "\n",
    "model.compile(loss='sparse_categorical_crossentropy', optimizer='adam')\n",
    "model.summary()\n",
    "model.load_weights('model_weights.h5')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Beam Search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 2, 4])"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.unique(np.random.choice(5,3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def beam_search(seq_len):\n",
    "    p = model.predict(np.array(letter)[None,:])\n",
    "    np.unique(np.random.choice(len(char2int),10,p=p[0][-1])[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
